package java;

public class Question18 {

    /**
     * Demo driver: builds the list {@code 1 -> 2 -> 3 -> 4 -> 5}, removes the 2nd node from the
     * end, and prints the resulting list ({@code 1 -> 2 -> 3 -> 5}).
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int n = 2;
        ListNode head = new ListNode(1);
        ListNode tail = head;
        for (int value = 2; value <= 5; value++) {
            tail.next = new ListNode(value);
            tail = tail.next;
        }

        ListNode root = removeNthFromEnd(head, n);

        StringBuilder out = new StringBuilder();
        for (ListNode node = root; node != null; node = node.next) {
            if (out.length() > 0) {
                out.append(" -> ");
            }
            out.append(node.val);
        }
        System.out.println(out);
    }

    /**
     * Removes the {@code n}-th node counted from the end of a singly linked list and returns the
     * head of the resulting list. Positions are 1-based, so {@code n == 1} removes the tail.
     *
     * <p>When a node other than the head is removed, the list is modified in place. When the head
     * itself is removed, {@code head.next} is returned. If {@code n} is not positive or is larger
     * than the list length, the list is returned unchanged.
     *
     * @param head first node of the list; may be {@code null}
     * @param n 1-based position, counted from the end, of the node to remove
     * @return the head of the list after the removal
     */
    public static ListNode removeNthFromEnd(ListNode head, int n) {
        // Non-positive positions name no node. Without this guard, n == 0 walked both pointers
        // to the tail and then dereferenced tail.next (null).
        if (head == null || n <= 0) {
            return head;
        }

        // A sentinel in front of head lets "remove the head" share the general code path.
        ListNode sentinel = new ListNode(0);
        sentinel.next = head;

        // Move `lead` n nodes past the sentinel. If the list is shorter than n, nothing to remove.
        ListNode lead = sentinel;
        for (int i = 0; i < n; i++) {
            if (lead.next == null) {
                return head;
            }
            lead = lead.next;
        }

        // Advance both pointers until `lead` is the tail. `trail` then sits immediately before
        // the n-th node from the end.
        ListNode trail = sentinel;
        while (lead.next != null) {
            lead = lead.next;
            trail = trail.next;
        }
        trail.next = trail.next.next;
        return sentinel.next;
    }
}

class ListNode {
    int val;
    ListNode next;
    ListNode(int x) { val = x; }
}